% updated 07/07/2017

% model 1: Y = lam*A*Y + rho*Y_ttl + X*b + error
%% CONTROL PANEL
% Run-time settings for this script. jaffepx is only given a default when it
% is not already defined, so a calling script may set it beforehand.
%clear
parentpath = cd(cd('..')); % absolute path of the parent folder (current folder is restored)

warning('ON','msg:id')
dataset = 'cati_merged_sdc_rnd'; % Choose from: {'cati','sdc_rnd','sdc_rnd_ctt','cati_merged_sdc_rnd','cati_merged_sdc_rnd_ctt'}

rndlink = 0; % 1 if link duration is random, 0 if link duration is fixed
lnkyear = 5; % Choose from: {3,4,5,6,7}. NOTE: the lagged link predictors
             % require lnkyear <= yr1+t0-yr0 (= 5 with the years below).
taxcred = 0; % 1 if use tax credit as IV for RND stock

% 1 if use jaffe proximity, 0 if use mahalanobis proximity.
% This variable is read by the link-prediction and export sections. It used
% to be commented out, which made the script fail unless a caller defined it.
if ~exist('jaffepx','var')
    jaffepx = 1;
end
    
timelag = 1; % time lag of RND stock
sicdgit = 1; % Choose from: {1,10,100,1000}

fac = 1e3; % scale factor for output

%yr0 = 1962;
%yr1 = 1966;
yr0 = 1976; % first year of the link-prediction panel
yr1 = 1980; % first year of the firm panel
yr2 = 2006; % last year of both panels

% drop the first t0 years (i.e. the starting year is (yr1+t0))
t0 = 1;
%%
% Number of firm IDs in the chosen dataset; later code treats firm IDs as 1..n.
switch dataset
    case 'cati'
        n = 11093;
    case 'sdc_rnd'
        n = 12253;
    case 'sdc_rnd_ctt'
        n = 21387;
    case 'cati_merged_sdc_rnd'
        n = 19654;
    case 'cati_merged_sdc_rnd_ctt'
        n = 27147;
    otherwise
        % Stop here: continuing would fail later with n undefined.
        error('No appropriate dataset specified: ''%s''.', dataset)
end
T = yr2-yr1+1; % number of years in the firm panel
%%
% Give every firm one time-invariant SIC code, taken from the yearly balance
% sheets for yr1..yr2. ID(:,1) is the firm ID and ID(:,2) the SIC code:
% 0 = never seen with a valid (4-digit) SIC, -1 = SIC changed over time.
% Both cases are dropped after the loop.
ID = zeros(n,2);
ID(:,1) = (1:n)';
for s = 1:T
    yr = yr1+s-1;
    file = [parentpath '/' dataset '/balance_sheets/' int2str(yr) '_ID SIC SALES PROFIT RND EMPL FIXED_CAPITAL TOT_COST.csv'];
    flid = fopen(file);
    if flid < 0
        error('Cannot open file: %s', file)
    end
    dat0 = textscan(flid,'%s %s %s %s %s %s %s %s','delimiter',';','headerlines',1);
    fclose(flid);
    %---------------------------------   
    FID = str2double(strtrim(dat0{1})); % FID
    SIC = str2double(strtrim(dat0{2})); % SIC
    FID(isnan(FID)) = 0;
    SIC(isnan(SIC)) = 0;
    SIC(floor(SIC/1000)==0) = 0; % drop firms with less than 4-digit SIC
    
    % Vectorized lookup of this year's SIC for firms 1..n. This replaces a
    % per-firm FID==i scan over all rows. ismember uses the first matching
    % row; the old loop assumed each FID appears at most once per year.
    [found, row] = ismember(ID(:,1), FID);
    sic = zeros(n,1);
    sic(found) = SIC(row(found));
    upd   = (sic > 0) & (sic ~= ID(:,2)); % valid SIC that differs from the stored code
    first = upd & (ID(:,2) == 0);         % first valid SIC observed for this firm
    ID(first,2) = sic(first);
    ID(upd & ~first,2) = -1;              % SIC code changes over time
    %warning('msg:id','SIC code changes over time')
end
ID(:,2) = floor(ID(:,2)/sicdgit);
ID = ID(ID(:,2)>0,:); % drop firms with missing SIC
n = size(ID,1);

% Keep only firms whose NATION_ID (2nd column of the location file) is 235.
% A firm counts only if its ID occurs exactly once in that file.
% NOTE(review): 235 presumably codes the United States - confirm.
file = [parentpath '/' dataset '/location/' 'ID NATION_ID NATION CITY LOC_X LOC_Y.csv'];
flid = fopen(file);
raw = textscan(flid,'%s %s %s %s %s %s','delimiter',';','headerlines',1);
fclose(flid);
locid = str2double(strtrim(raw{1}));
natid = str2double(strtrim(raw{2}));

nation = zeros(size(ID,1),1);
for k = 1:n
    hit = find(locid == ID(k,1));
    if numel(hit) == 1
        nation(k) = natid(hit);
    end
end
ID = ID(nation == 235,:);
n = size(ID,1);
%%
% For every panel year yr1..yr2, build the firm-level variables X1, the
% network spillovers X2 = A*X1 (A = R&D adjacency matrix restricted to the
% retained firms), the sector-level variables X3 (sector totals minus the
% firm itself), and an indicator dm of usable firm-year observations.
% Columns of firm/aggr: 1 = output (divided by fac), 2 = R&D stock, and
% 3 = tax variable; columns 2 and 3 come from year yr-timelag.
dm = cell(T,1);
X1 = cell(T,1);
X2 = cell(T,1);
X3 = cell(T,1);
%---------------------------------
% The firm-level tax file is the same for every year. Read it once and
% filter it by year inside the loop (it was previously re-read T times).
file = [parentpath '/' dataset '/tax/bloom_firm_id_year_tax.csv'];
flid = fopen(file);
dat3 = textscan(flid,'%s %s %s','delimiter',',','headerlines',1);
fclose(flid);
taxfid = str2double(strtrim(dat3{1})); % firm ID
taxyr  = str2double(strtrim(dat3{2})); % year
taxval = str2double(strtrim(dat3{3})); % tax value
for s = 1:T
    yr = yr1+s-1;
    %---------------------------------
    firm = zeros(n,3);
    %---------------------------------
    file = [parentpath '/' dataset '/balance_sheets/' int2str(yr) '_ID OUTPUT.csv'];
    flid = fopen(file);
    dat1 = textscan(flid,'%s %s','delimiter',';','headerlines',1);
    fclose(flid);
    fid1 = str2double(strtrim(dat1{1}));
    outp = str2double(strtrim(dat1{2}));
    %---------------------------------
    file = [parentpath '/' dataset '/balance_sheets/' int2str(yr-timelag) '_ID RND_STOCK.csv'];
    flid = fopen(file);
    dat2 = textscan(flid,'%s %s','delimiter',';','headerlines',1);
    fclose(flid);
    fid2 = str2double(strtrim(dat2{1}));
    stck = str2double(strtrim(dat2{2}));
    %---------------------------------
    % this year's (lagged) rows of the firm-level tax data
    taxsel = taxyr==(yr-timelag);
    fid3 = taxfid(taxsel);
    taxc = taxval(taxsel);
    %---------------------------------
    % A firm gets data only if it appears exactly once in both the output file
    % and the R&D-stock file. Its tax value stays 0 unless it appears exactly
    % once in this year's tax rows.
    for i = 1:n
        if (sum(fid1==ID(i,1))==1)&&(sum(fid2==ID(i,1))==1)
            firm(i,1) = outp(fid1==ID(i,1))/fac;
            firm(i,2) = stck(fid2==ID(i,1));
            
            if (sum(fid3==ID(i,1))==1)
                firm(i,3) = taxc(fid3==ID(i,1));
            else
                firm(i,3) = 0;
            end
        end
    end
    firm(isnan(firm)) = 0;
    %---------------------------------
    file = [parentpath '/' dataset '/tot_output_compustat/' int2str(yr) '_Compustat_US_deflated_TOT_OUTPUT.csv'];
    flid  = fopen(file);
    dta1 = textscan(flid,'%s %s','delimiter',';','headerlines',1);
    fclose(flid);
    sec1 = [str2double(strtrim(dta1{1})),str2double(strtrim(dta1{2}))/fac];
    sec1(isnan(sec1)) = 0;
    sec1(:,1) = floor(sec1(:,1)/sicdgit);
    %---------------------------------
    file = [parentpath '/' dataset '/tot_rnd_stock/' int2str(yr-timelag) '_Compustat_US_TOT_RND_STOCK.csv'];
    flid  = fopen(file);
    dta2 = textscan(flid,'%s %s','delimiter',';','headerlines',1);
    fclose(flid);
    sec2 = [str2double(strtrim(dta2{1})),str2double(strtrim(dta2{2}))];
    sec2(isnan(sec2)) = 0;
    sec2(:,1) = floor(sec2(:,1)/sicdgit);
    %---------------------------------
    file = [parentpath '/' dataset '/tax/' int2str(yr-timelag) '_Bloom_US_TOT_TAX.csv'];
    flid  = fopen(file);
    dta3 = textscan(flid,'%s %s','delimiter',';','headerlines',1);
    fclose(flid);
    sec3 = [str2double(strtrim(dta3{1})),str2double(strtrim(dta3{2}))];
    sec3(isnan(sec3)) = 0;
    sec3(:,1) = floor(sec3(:,1)/sicdgit);
    %---------------------------------
    % add the sector-level data to the firm-level data
    % (sector total over rows with the firm's SIC code, minus the firm's own value)
    aggr = zeros(n,3);
    for i = 1:n
        aggr(i,1) = sum(sec1(sec1(:,1)==ID(i,2),2))-firm(i,1);
        aggr(i,2) = sum(sec2(sec2(:,1)==ID(i,2),2))-firm(i,2);
        aggr(i,3) = sum(sec3(sec3(:,1)==ID(i,2),2))-firm(i,3);
    end
    %---------------------------------
    % Load the RND adjacency matrix. Only A is loaded, into a struct, so that
    % other variables saved in the .mat file cannot overwrite workspace
    % variables such as s, n or ID.
    if rndlink == 1
        file = [parentpath '/' dataset '/adjacency_matrix_rand_' num2str(lnkyear) '_years/' int2str(yr) '_adjacency_matrix.mat'];
    else
        file = [parentpath '/' dataset '/adjacency_matrix_' num2str(lnkyear) '_years/' int2str(yr) '_adjacency_matrix.mat'];
    end
    S = load(file,'A');
    A = S.A(ID(:,1),ID(:,1));
    %---------------------------------
    % Mark firms with missing observations. A firm-year is usable only if
    % output and R&D stock are positive and all sector aggregates are >= 0.
    dm{s} = ones(n,1);
    for i = 1:n
        if (min(firm(i,1:2))<=0)||(min(aggr(i,1:3))<0)
            dm{s}(i) = 0;
        end
    end
    %---------------------------------
    % define variables
    X1{s} = firm;
    X2{s} = A*firm;
    X3{s} = aggr;
end
%%
% Keep only the estimation years t0+1..T of the usable-observation flags,
% stacked year after year into one long column (n rows per year).
dt = vertcat(dm{:});
dt = dt(t0*n+1:end);
T1 = T-t0;
%---------------------------------
% Dn maps each stacked firm-year row to its firm. A firm is kept only if it
% is usable in at least two of the retained years.
di = (1:T1*n).';
dj = repmat((1:n).', T1, 1);
Dn = sparse(di,dj,dt);
dd = sum(Dn,1);          % number of usable years per firm
keep = dd > 1;
Dn = Dn(:,keep);
np = size(Dn,2);         % number of firms kept

% Row-selection matrix for the kept firms, applied to IDs and all yearly data
En = speye(n);
En = En(keep,:);
ID = En*ID;
dm = cellfun(@(v) En*v, dm, 'UniformOutput', false);
X1 = cellfun(@(v) En*v, X1, 'UniformOutput', false);
X2 = cellfun(@(v) En*v, X2, 'UniformOutput', false);
X3 = cellfun(@(v) En*v, X3, 'UniformOutput', false);
%% prediction of links
% Candidate link pairs: all firm pairs (i,j) with i > j (strict lower
% triangle) that share the same floor(ID(:,2)/100) group.
ID2 = floor(ID(:,2)/100);
grp = bsxfun(@eq,ID2,ID2');
grp(triu(true(size(grp))))=0;
[gi,gj,~] = find(grp);
%---------------------------------
% Pair-by-year panels (rows = candidate pairs, columns = years yr0..yr2):
%   Wy - entry of the adjacency matrix A
%   Wx - indicator(A^2>0) minus indicator(A>0)
%   Wa, Wb - row sums of Wy, Wx over all columns filled so far
%   Wp - technological proximity (Jaffe if jaffepx == 1, else Mahalanobis)
%   Wd - 1 if C+C' > 0 (production matrix, either direction)
nl = length(gi);
T2 = yr2-yr0+1;
Wy = zeros(nl,T2);
Wx = zeros(nl,T2);
Wa = zeros(nl,T2);
Wb = zeros(nl,T2);
Wp = zeros(nl,T2);
Wd = zeros(nl,T2);
for s = 1:T2
    yr = yr0+s-1;
    
    % Each .mat file is loaded into a struct, and only the needed variable is
    % taken. A bare load(file) would also overwrite any workspace variable
    % saved in the file (e.g. s, ID, n).
    % NOTE(review): unlike the yearly-variables loop, rndlink is not consulted
    % here, so the fixed-duration adjacency matrices are always used - confirm.
    file = [parentpath '/' dataset '/adjacency_matrix_' num2str(lnkyear) '_years/' int2str(yr) '_adjacency_matrix.mat'];
    S = load(file,'A');
    A = S.A(ID(:,1),ID(:,1));
    % NOTE(review): B is -1 where A>0 but A^2==0 (direct link with no common
    % neighbour) unless A has a positive diagonal - confirm this is intended.
    B = double(A^2>0)-double(A>0);
    %---------------------------------
    if jaffepx == 1
        file = [parentpath '/' dataset '/patents/' int2str(yr) '_jaffe_proximity.mat'];
        S = load(file,'P');
        P = S.P;
    else
        file = [parentpath '/' dataset '/patents/' int2str(yr) '_mahalanobis_proximity.mat'];
        S = load(file,'M');
        P = S.M;
    end
    P = P(ID(:,1),ID(:,1));
    
    file = [parentpath '/' dataset '/production_matrix/' int2str(yr) '_production_matrix.mat'];
    S = load(file,'C');
    C = S.C(ID(:,1),ID(:,1));
    D = zeros(size(C));
    D(C+C'>0) = 1;
    %---------------------------------
    Wy(:,s) = A(grp==1);
    Wx(:,s) = B(grp==1);
    Wa(:,s) = sum(Wy,2); % later columns are still zero, so this sums years 1..s
    Wb(:,s) = sum(Wx,2);
    Wp(:,s) = P(grp==1);
    Wd(:,s) = D(grp==1);
end
%---------------------------------
% Co-location dummy dc for every candidate pair, replicated for the T1
% estimation years. Firm coordinates come from LOC_X / LOC_Y (columns 5-6).
file = [parentpath '/' dataset '/location/' 'ID NATION_ID NATION CITY LOC_X LOC_Y.csv'];
fid  = fopen(file);
dat0 = textscan(fid,'%s %s %s %s %s %s','delimiter',';','headerlines',1);
fclose(fid);
tmp1 = str2double(strtrim(dat0{1}));
tmp2 = str2double(strtrim(dat0{5}));
tmp3 = str2double(strtrim(dat0{6}));

loc = zeros(np,2);
hasloc = false(np,1); % true if the firm ID occurs exactly once in the location file
for i = 1:np
    if sum(tmp1==ID(i,1)) == 1
        loc(i,1) = tmp2(tmp1==ID(i,1));
        loc(i,2) = tmp3(tmp1==ID(i,1));
        hasloc(i) = true;
    end
end
dis = pdist2(loc,loc);
% Firms without a location keep the placeholder (0,0). Without the hasloc
% mask, every pair of such firms had distance 0 and counted as co-located.
prx = zeros(np);
prx(dis==0 & bsxfun(@and,hasloc,hasloc')) = 1;
dc = prx(grp==1);
dc = kron(ones(T1,1),dc);
%---------------------------------
% Same-sector indicators for every candidate pair, replicated for T1 years:
%   Ws1 / ds1 - identical ID(:,2) code
%   Ws2 / ds2 - identical code after dropping the last digit of ID(:,2)
sec = floor(ID(:,2));
Ws1 = bsxfun(@eq, sec, sec.');
sec = floor(ID(:,2)/10);
Ws2 = bsxfun(@eq, sec, sec.');

ds1 = kron(ones(T1,1), Ws1(grp==1));
ds2 = kron(ones(T1,1), Ws2(grp==1));
%---------------------------------
% Stack the last T1 years of the pair panels into the link-prediction data.
% The outcome dy is the contemporaneous adjacency entry; the predictors are
% taken lnkyear years earlier. Columns T2-T1+1-lnkyear.. must exist, which
% requires lnkyear <= T2-T1 (= yr1+t0-yr0).
if lnkyear > T2-T1
    error('lnkyear (%d) must not exceed yr1+t0-yr0 (%d).', lnkyear, T2-T1)
end
n0 = nl*T1;
dy = reshape(Wy(:,T2-T1+1:T2),n0,1);
da = reshape(Wa(:,T2-T1+1-lnkyear:T2-lnkyear),n0,1);
db = reshape(Wb(:,T2-T1+1-lnkyear:T2-lnkyear),n0,1);
dp = reshape(Wp(:,T2-T1+1-lnkyear:T2-lnkyear),n0,1);
d2 = reshape(Wd(:,T2-T1+1-lnkyear:T2-lnkyear),n0,1);
%da = reshape(Wa(:,T2-T1-lnkyear:T2-lnkyear-1),n0,1);
%db = reshape(Wb(:,T2-T1-lnkyear:T2-lnkyear-1),n0,1);
%dp = reshape(Wp(:,T2-T1-lnkyear:T2-lnkyear-1),n0,1);

% regressors (with intercept) for the logit
dx = [da,db,dp,dp.^2,d2,dc,ds1,ones(n0,1)];
kx = size(dx,2);

% same data without the intercept, exported as y, z1, z2, ...
temp = [dy,da,db,dp,dp.^2,d2,dc,ds1];
name = cell(1,size(temp,2));
name{1} = 'y';
for j = 1:size(temp,2)-1
    name{j+1} = ['z' num2str(j)];
end

% fullfile gives a portable path (a literal '\' is a plain filename
% character outside Windows); delete only warns when the file is absent.
filename = fullfile('stata_data1',['linkdata_io_' num2str(jaffepx) '.csv']);
if exist(filename,'file')
    delete(filename);
end
csvwrite_with_headers(filename,temp,name);
%---------------------------------
% Logit model of link formation: P(dy=1) = 1/(1+exp(-dx*b0)).
% Start from the OLS coefficients, then optimize logit_obj with fminunc.
% With 'GradObj' and 'Hessian' set to 'on', logit_obj must return the
% objective, gradient and Hessian.
xx = dx'*dx;
b0 = xx\(dx'*dy);
option = optimoptions(@fminunc,'Algorithm','trust-region','GradObj','on','Hessian','on','Display','notify','DerivativeCheck','off');
% Extra data is passed through an anonymous function rather than the
% deprecated "extra arguments after options" calling form.
[b0,FVAL,EXITFLAG,OUTPUT,GRAD,HESSIAN] = fminunc(@(b) logit_obj(b,dy,dx),b0,option);
% standard errors from the inverse Hessian
% (assumes logit_obj returns the negative log-likelihood - confirm)
s0 = sqrt(diag(HESSIAN\speye(kx)))';
% Fitted probabilities. 1/(1+exp(-z)) avoids the Inf/Inf = NaN that
% exp(z)/(1+exp(z)) produces for large z.
p0 = 1./(1+exp(-dx*b0));
%---------------------------------
% Log-likelihood of the fitted model. Terms with zero weight are skipped, so
% a fitted probability of exactly 0 or 1 does not produce 0*log(0) = NaN.
ll = zeros(n0,1);
k1 = dy ~= 0;
ll(k1) = dy(k1).*log(p0(k1));
k0 = dy ~= 1;
ll(k0) = ll(k0)+(1-dy(k0)).*log(1-p0(k0));
d0 = sum(ll);
% log-likelihood of the intercept-only model (for binary dy)
y0 = mean(dy);
d1 = n0*(y0*log(y0)+(1-y0)*log(1-y0));
R2 = 1-d0/d1; % McFadden's R-squared (Cameron and Trivedi, p.474)
%http://www.ats.ucla.edu/stat/mult_pkg/faq/general/Psuedo_RSquareds.htm
%---------------------------------
% For each estimation year, replace X2 columns 2:3 (R&D-stock and tax
% spillovers) by spillovers through the predicted link probabilities: that
% year's slice of p0, placed on the candidate pairs and symmetrized.
for k = 1:T1
    idx = (k-1)*nl + (1:nl);
    Ap = sparse(gi, gj, p0(idx), np, np);
    Ap = Ap + Ap.';
    X2{t0+k}(:,2:3) = Ap * X1{t0+k}(:,2:3);
end
%---------------------------------
% Append a summary of the link-prediction logit to a text log: the date,
% n0, the proximity measure used, the McFadden R^2, and each coefficient with
% its standard error. Stars use two-sided normal critical values:
% *** 1% (2.576), ** 5% (1.96), * 10% (1.645).
logfile = fullfile('stata_data1','link_prediction_io_output.txt');
fid = fopen(logfile,'a');
if fid < 0
    error('Cannot open log file: %s', logfile)
end
fprintf(fid,'%s \n',date);
fprintf(fid,'# of links is %i \n',n0); % n0 = candidate pairs x estimation years
if jaffepx == 1
    fprintf(fid,'Use jaffe proximity to predict links \n');
else
    fprintf(fid,'Use mahalanobis proximity to predict links \n');
end

fprintf(fid,'link prediction R^2 is %7.4f \n',R2);
for ip = 1:kx
    tstat = abs(b0(ip)/s0(ip));
    if tstat >= 2.576
        fprintf(fid,' %7.4f***\n (%6.4f)\n',b0(ip),s0(ip));
    elseif tstat >= 1.96
        fprintf(fid,' %7.4f**\n (%6.4f)\n',b0(ip),s0(ip));
    elseif tstat >= 1.645
        fprintf(fid,' %7.4f*\n (%6.4f)\n',b0(ip),s0(ip));
    else
        fprintf(fid,' %7.4f\n (%6.4f)\n',b0(ip),s0(ip));
    end
end
fclose(fid);
%%
% Assemble the estimation sample for years s = t0+1..T, keeping only the
% firm-years flagged usable in dm{s}, and export it as CSV:
%   D1       - row-selection matrices picking each year's usable firms
%   D2       - year indicators (column s-t0 set to 1)
%   Yn       - X1(:,1), firm output
%   Zn       - [X2(:,1) X3(:,1) X1(:,2)]
%   Qn       - column 3 (taxcred == 1) or column 2 of X2, X3 and X1
%              (presumably the instruments for Zn - confirm)
%   pid, tid - firm ID and year index of each stacked row
D1 = cell(T1,1);
D2 = cell(T1,1);
Yn = cell(T1,1);
Zn = cell(T1,1);
Qn = cell(T1,1);
pid = cell(T1,1);
tid = cell(T1,1);
for s = (t0+1):T
    Da = diag(dm{s});
    Da = Da(sum(Da,2)==1,:);
    Db = zeros(size(Da,1),T1);
    Db(:,s-t0) = 1;
    D1{s-t0} = Da;
    D2{s-t0} = Db;    
    
    Yn{s-t0} = Da*X1{s}(:,1);
    Zn{s-t0} = Da*[X2{s}(:,1),X3{s}(:,1),X1{s}(:,2)];
    if taxcred == 1
        Qn{s-t0} = Da*[X2{s}(:,3),X3{s}(:,3),X1{s}(:,3)];
    else
        Qn{s-t0} = Da*[X2{s}(:,2),X3{s}(:,2),X1{s}(:,2)];
    end
    pid{s-t0} = Da*ID(:,1);
    tid{s-t0} = (s-t0)*ones(size(Da,1),1);
end

D1 = cat(1,D1{:});
D2 = cat(1,D2{:});
Yn = cat(1,Yn{:});
Zn = cat(1,Zn{:});
Qn = cat(1,Qn{:});

% columns: pid, tid, y, z1..zK, q1..qL
temp = [cat(1,pid{:}),cat(1,tid{:}),Yn,Zn,Qn];
name = cell(1,size(temp,2));
name{1} = 'pid';
name{2} = 'tid';
name{3} = 'y';
for j = 1:size(Zn,2)
    name{j+3} = ['z' num2str(j)];
end
for j = 1:size(Qn,2)
    name{j+3+size(Zn,2)} = ['q' num2str(j)];
end

% fullfile gives a portable path (a literal '\' is a plain filename
% character outside Windows); delete only warns when the file is absent.
filename = fullfile('stata_data1',['data2_io_' num2str(jaffepx) '.csv']);
if exist(filename,'file')
    delete(filename);
end
csvwrite_with_headers(filename,temp,name);